
WIP [ilu/ttx] optimize quant_moe_export #343

Draft
songliucsu wants to merge 1 commit into master from optimize_quant_moe

@songliucsu

Collaborator
  1. Rank-1 accumulation → matrix-engine tl.dot: The int8 GEMM previously fell back to a rank-1 outer-product accumulation because ILU's int8 tl.dot miscompiles (invalid bitcast → segfault). Now the int8 operands (|v| ≤ 127) are losslessly cast to fp16 and fed to an fp16 MMA; each BLOCK_K tile is computed exactly in the fp32 dot output, cast back to int32 and accumulated, then dequantized per group. This uses the matrix engine while staying bit-exact (autotune configs gained a BLOCK_K dim). ~6x speedup.

  2. Deduplicated device→host syncs: The four sub-steps (input quant / up-GEMM-swiglu / requant / down-GEMM) each recomputed group_offsets and called .max().item(), giving 4 device→host syncs that serialized the launches. Now _make_group_offsets() computes them once at the entry and threads them through all steps, leaving a single sync. ~28% faster.

  3. Intermediate fc1_out fp32 → bf16: The SwiGLU output is stored as bf16 (matching ixformer's up-GEMM output) instead of fp32, halving the intermediate round-trip traffic. Accuracy-neutral. Perf-neutral at the benchmarked (launch-bound) sizes, but it should help at larger batches and it lowers the memory footprint.

Cumulative (24 experts, top_k=4, hidden=512, inter=1280, 97 tokens): int8 ~9.6 → 1.10 ms (~8.7x), int4 ~10.4 → 1.13 ms (~9.2x). All verified bit-accurate for both int8 and int4.
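
The exactness argument in item 1 can be checked on the host with a small sketch. This is a pure-Python emulation, not the kernel: `to_fp16` and `to_fp32` are hypothetical helpers that round through the half- and single-precision formats of Python's `struct` module, and the sketch assumes the multiply itself is not rounded to fp16 (the property the ILU MMA is claimed to have).

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round x to the nearest IEEE single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# 1) Every int8 value survives the cast to fp16 unchanged.
int8_roundtrip_exact = all(to_fp16(float(v)) == v for v in range(-128, 128))

# 2) Worst-case tile: BLOCK_K products of 127 * 127, formed and summed in fp32.
BLOCK_K = 128
acc = 0.0
for _ in range(BLOCK_K):
    prod = to_fp32(to_fp16(127.0) * to_fp16(127.0))  # assumed fp32 product of fp16 inputs
    acc = to_fp32(acc + prod)                        # fp32 accumulation
worst_case_exact = int(acc) == BLOCK_K * 127 * 127
```

Both flags come out true because the largest tile sum, 128 * 127^2, stays below 2^24, the limit up to which fp32 represents every integer exactly.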

@github-actions

github-actions Bot commented Jun 3, 2026

Claude Code Review

Verdict: Comment -- Switch from rank-1 int32 accum to fp16 tl.dot plus shared group-offset reuse looks reasonable, but there are correctness/perf concerns around BLOCK_K vs QUANT_GROUP_SIZE and the fc1 dtype change.

Summary

The PR replaces the per-K rank-1 int8 accumulation with a tiled fp16 tl.dot (int8 cast losslessly to fp16, dequantized per quant-group), adds BLOCK_K to the autotuner, and hoists the per-batch group_offsets/max_m computation so the single device->host sync is shared across all four launches in quant_moe_experts_impl. It also drops the fc1_out intermediate from fp32 to the model dtype (bf16).

Must fix

  • [BLOCKER] BLOCK_K may exceed QUANT_GROUP_SIZE -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:118-138 -- Autotune offers BLOCK_K in {32,64,128}. If quant_group_size < BLOCK_K (e.g. group size 32 or 64), K_TILES_PER_GROUP == 1 and the single tile loads BLOCK_K columns. The k_in_group < QUANT_GROUP_SIZE mask zeros the out-of-group lanes, and the loop advances in kg * QUANT_GROUP_SIZE strides, so each quant group still reads only its own QUANT_GROUP_SIZE lanes. This is numerically fine, but it wastes up to 4x the K work and the autotuner can still pick such configs. Please confirm the group sizes used in production, then either filter configs with BLOCK_K > QUANT_GROUP_SIZE in _quant_moe_autotune_configs (or assert at launch), or document the behavior.
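
The "up to 4x" figure can be reproduced with a short host-side sketch. It assumes K_TILES_PER_GROUP is the ceil-division of QUANT_GROUP_SIZE by BLOCK_K and that each tile always loads BLOCK_K lanes; `k_tile_waste` is a hypothetical helper, not a function in the PR.

```python
def k_tile_waste(quant_group_size: int, block_k: int) -> tuple[int, float]:
    """Return (tiles per group, loaded lanes / useful lanes) for one quant group."""
    tiles = -(-quant_group_size // block_k)   # ceil division
    loaded = tiles * block_k                  # lanes fetched, including masked ones
    return tiles, loaded / quant_group_size

# Small group sizes from the review, against the autotuned BLOCK_K candidates.
waste_table = {
    (g, bk): k_tile_waste(g, bk)
    for g in (32, 64, 128)
    for bk in (32, 64, 128)
}
```

Under these assumptions `waste_table[(32, 128)]` is `(1, 4.0)`: one tile per group, with four lanes loaded for every useful one.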

Suggestions

  • [MAJOR] fc1_out dtype change is a precision regression vs prior contract -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:537-540 -- The old comment explicitly said "core keeps activated in fp32 before passing to down_proj_quantize"; switching to bf16 changes numerical output and the justification ("clamps to ~7 bits anyway") ignores that the dynamic per-token scale is computed from this tensor's amax, so bf16 rounding shifts the scale. Worth a unit/parity test against fp32 reference before landing.
  • [MAJOR] fp16 dot with int8 magnitudes up to 127 and BLOCK_K=128 -- moe_quant_experts.py:130 -- The partial-sum bound is 128 * 127^2 ~= 2.06e6, well under fp32's exact-integer limit of 2^24, so the sum is exact only if it is accumulated in fp32. Per-element products up to 127*127=16129 exceed fp16's exactly-representable integer range (2^11=2048), so individual products would not be exact if rounded to fp16, but tl.dot with out_dtype=fp32 accumulates in fp32 from fp16 multiplicands. The int8 -> fp16 cast of the multiplicands is exact (every integer with |v|<=2048 is). Verify that the ILU MMA actually accumulates in fp32; if it accumulates in fp16 internally the result is wrong.
  • [MINOR] tl.trans(b).to(fp16) order -- moe_quant_experts.py:130,138 -- Casting before transpose (b.to(fp16) then tl.trans) typically produces better layouts on Triton; current order transposes int8 then casts. Likely codegen-equivalent but worth checking generated PTX/LLIR.
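
The difference between exact fp16 inputs and inexact fp16 products can be illustrated on the host. This is a minimal sketch, assuming IEEE half precision as emulated by the `'e'` format of Python's `struct` module; `to_fp16` is a hypothetical helper and says nothing about what the ILU hardware does.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# fp16 has an 11-bit significand, so every integer with |n| <= 2048 is exact.
exact_up_to_2048 = all(to_fp16(float(n)) == n for n in range(-2048, 2049))

# The largest int8 * int8 product needs 14 significand bits and is not exact.
product = 127 * 127                         # 16129
product_in_fp16 = to_fp16(float(product))   # fp16 spacing is 8 in [8192, 16384)
product_is_exact = product_in_fp16 == product
```

Here `product_in_fp16` is 16128.0, which is why the result is only bit-exact if products are formed and accumulated at fp32 precision.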

Notes

  • [CHECK] K_TILES_PER_GROUP is constexpr only if QUANT_GROUP_SIZE and BLOCK_K are both constexpr at this point -- confirm both are passed as tl.constexpr (QUANT_GROUP_SIZE is, BLOCK_K is); otherwise the range(K_TILES_PER_GROUP) won't unroll.
  • [CHECK] Reused group_offsets tensor is now read by four kernels in sequence on the same stream -- fine for in-order execution, but if any launch becomes async across streams in the future this becomes a hazard.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request refactors the Triton int8 grouped matmul kernel to use matrix engine-based tl.dot by losslessly casting int8 operands to fp16, avoiding previous compiler bugs. It also introduces a performance optimization that computes group offsets and maximum token counts once and shares them across all kernel launches to minimize device-to-host synchronization overhead. Additionally, the intermediate fc1_out tensor is changed from float32 to the input precision to reduce memory bandwidth. Feedback suggests using torch.zeros instead of torch.empty followed by a manual zero-write to avoid an unnecessary host-to-device scalar write.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py Outdated
@songliucsu songliucsu force-pushed the optimize_quant_moe branch from 0dddddc to 4204d31 June 3, 2026 07:49
@github-actions

github-actions Bot commented Jun 3, 2026

Claude Code Review

Verdict: Request changes -- Int8->fp16 dot path can lose precision when BLOCK_K exceeds the quant group size, and the autotuner key does not include QUANT_GROUP_SIZE so the prune hook may be ineffective.

Summary

Replaces the rank-1 int32 accumulation in the grouped int8 MoE GEMM with an fp16 tl.dot (cast losslessly from int8) that accumulates int32 per quant group, and threads pre-computed group offsets / max_m through the four kernel launches to avoid repeated device-host syncs.

Must fix

  • [BLOCKER] fp16 dot precision claim only holds when BLOCK_K <= QUANT_GROUP_SIZE -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:155-165 -- The comment argues exactness because each tile partial sum is bounded by BLOCK_K * 127^2 < 2^24, but the prune hook only fires when QUANT_GROUP_SIZE is in named_args/kwargs. QUANT_GROUP_SIZE is not in the autotune key and is passed positionally as a constexpr, so named_args.get("QUANT_GROUP_SIZE") will typically return None and no configs are pruned -- a BLOCK_K=128 config can be selected for quant_group_size=32, masking 3/4 of lanes (still correct) but more importantly the rationale for skipping per-tile dequant assumes BLOCK_K <= group size. Verify the prune hook actually receives QUANT_GROUP_SIZE (Triton passes constexprs in named_args only if listed) and add it to key= so the autotuner re-runs per group size; otherwise enforce BLOCK_K <= QUANT_GROUP_SIZE at launch.
  • [BLOCKER] fp16 mantissa cannot represent all int32 partial sums up to 2^24 -- moe_quant_experts.py:154-160 -- tl.dot(a_f16, b_f16, out_dtype=tl.float32) accumulates in fp32, but the inputs are fp16. A single int8*int8 product (up to 16129) is exact in fp32, though not in fp16, and the fp32 sum of up to 128 such terms is at most 128 * 16129 ~= 2.06e6, which is below 2^24 and therefore still exact. However, the cast tile.to(tl.int32) truncates rather than rounds. If any intermediate fp32 value were e.g. 1234.9999 because of an inexact input or product (it should not happen for |v|<=127, but it is worth asserting), truncation gives a wrong int. Use tl.extra.libdevice.rint or (tile + 0.5*sign).to(tl.int32) before casting, or accumulate via a path that returns int32 directly.
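
The effect of truncating versus rounding can be sketched in plain Python. It assumes, as the review states, that the float-to-int32 cast truncates toward zero; both helper names are hypothetical and the drifted values are made-up illustrations, not measurements from the kernel.

```python
def trunc_to_int(x: float) -> int:
    """Mimic a float -> int cast that truncates toward zero."""
    return int(x)

def round_half_away_to_int(x: float) -> int:
    """Add 0.5 * sign(x) before truncating, as suggested in the review."""
    return int(x + 0.5) if x >= 0 else int(x - 0.5)

# Hypothetical values that drifted slightly from the true integers +/-1235.
drifted = [1234.9999, -1234.9999, 1235.0001, -1235.0001]
truncated = [trunc_to_int(v) for v in drifted]
rounded = [round_half_away_to_int(v) for v in drifted]
```

Truncation recovers 1235 only when the drift is upward, while the rounded cast recovers it in all four cases. If the dot output is already an exact integer, both casts agree.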

Suggestions

  • [MAJOR] Autotune key missing QUANT_GROUP_SIZE -- moe_quant_experts.py:65 -- Different quant_group_size values produce different optimal BLOCK_K; without it in key, the cached choice from the first shape will be reused incorrectly.
  • [MAJOR] _prune_block_k_gt_group lookup is fragile -- moe_quant_experts.py:48-58 -- Triton passes constexprs to the prune hook through named_args keyed by parameter name; confirm with a unit test that pruning actually drops oversized configs (log len(kept) once) -- silent failure here costs perf, not correctness.
  • [MINOR] out_dtype=tl.float32 on fp16 dot -- moe_quant_experts.py:158 -- On ILU matrix engine, verify fp16->fp32 accumulation is supported and not silently downgraded to fp16 accumulation; a downgrade would break the exactness argument.

Notes

  • [CHECK] The new path doubles the loads per quant-group iteration only if K_TILES_PER_GROUP > 1. With the prune hook intended to match BLOCK_K to QUANT_GROUP_SIZE, confirm that K_TILES_PER_GROUP == 1 is the common case; otherwise the int32 partial accumulation across tiles still holds, but the comment ("each BLOCK_K tile ... computed exactly") may mislead.
  • [CHECK] tl.trans(b).to(tl.float16) -- confirm trans-then-cast is no worse than cast-then-trans on ILU; some backends prefer the latter for layout reasons.

@songliucsu songliucsu force-pushed the optimize_quant_moe branch from 4204d31 to b5946fa June 3, 2026 08:00
@github-actions

github-actions Bot commented Jun 3, 2026

Claude Code Review

Verdict: Comment -- Switches int8 grouped GEMM from rank-1 int32 accumulation to fp16 tl.dot with int32 partials, and threads shared group offsets through the pipeline; correctness hinges on a couple of fp16/tile-size assumptions worth a closer look.

Summary

The PR replaces the inner per-K rank-1 int32 accumulation in _quant_moe_gemm_kernel with a tiled tl.dot over int8 operands cast to fp16, accumulating each BLOCK_K tile into an int32 partial that is then dequantized per quant group. It also factors the group-offset / max-M computation into a single helper so the four kernel launches in quant_moe_experts_impl share one device->host sync.

Must fix

  • [BLOCKER] fp16 is not lossless for products above 2048 -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:165-170 -- The header claims the fp16 MMA is "exact" for int8 inputs, but tl.dot(a_f16, b_f16, out_dtype=fp32) may multiply in fp16 before fp32 accumulation. An int8*int8 product up to 16129 is within fp16's range but not always exactly representable (fp16 holds integers exactly only up to 2048), and any running fp16 partial within the dot loses precision once it exceeds 2048. The rounding-then-int32 step does not recover bits already lost inside the dot. Please verify on ILU that the dot truly performs an fp32 multiply-accumulate (not an fp16 multiply followed by fp32 accumulation), or switch to a known-exact path (e.g. accumulate fp32 and skip the int32 round-trip, or split into two int8 halves).

Suggestions

  • [MAJOR] Round-to-int32 cost and necessity -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:170,178 -- If the dot is truly exact for tile sums <= BLOCK_K*127^2 < 2^24, tile.to(tl.int32) already truncates an exact integer; the +/- 0.5 is only needed if the dot is inexact, in which case it is not sufficient (see blocker). Either drop the rounding or document the precise drift bound it is correcting.
  • [MAJOR] Autotune key includes QUANT_GROUP_SIZE but prune reads it from named_args -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:54-62 -- QUANT_GROUP_SIZE is a tl.constexpr passed positionally; confirm it shows up in named_args for the early_config_prune callback on this Triton version, otherwise the prune silently no-ops and the autotuner can pick BLOCK_K > group size configs (correct but wasteful).
  • [MINOR] or short-circuit on QUANT_GROUP_SIZE -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:60 -- named_args.get(...) or kwargs.get(...) will fall through to kwargs when the named value is 0; the followup if not qgs or qgs <= 0 then returns all configs. Harmless given the launcher normalizes to >0, but the double-fallback is confusing -- prefer an explicit if "QUANT_GROUP_SIZE" in named_args.
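
The explicit-lookup variant suggested above could look like the sketch below. It is not the PR's `_prune_block_k_gt_group`, whose body is not shown in this thread: `Config` is a stand-in that models only a `kwargs` dict, `prune_block_k_gt_group` is a hypothetical name, and the `(configs, named_args, **kwargs)` signature is an assumption modelled on Triton's `early_config_prune` callback. Whether QUANT_GROUP_SIZE actually appears in named_args on the Triton version in use still has to be tested.

```python
from dataclasses import dataclass, field

@dataclass
class Config:
    """Stand-in for an autotune config; only the `kwargs` dict is modelled."""
    kwargs: dict = field(default_factory=dict)

def prune_block_k_gt_group(configs, named_args, **kwargs):
    """Drop configs whose BLOCK_K exceeds QUANT_GROUP_SIZE, if the size is known."""
    if "QUANT_GROUP_SIZE" in named_args:      # explicit membership test, no `or` fall-through
        qgs = named_args["QUANT_GROUP_SIZE"]
    elif "QUANT_GROUP_SIZE" in kwargs:
        qgs = kwargs["QUANT_GROUP_SIZE"]
    else:
        return list(configs)                  # size unknown: prune nothing
    if qgs is None or qgs <= 0:
        return list(configs)
    kept = [c for c in configs if c.kwargs["BLOCK_K"] <= qgs]
    return kept or list(configs)              # never hand back an empty list

configs = [Config({"BLOCK_K": bk}) for bk in (32, 64, 128)]
kept_for_64 = [c.kwargs["BLOCK_K"]
               for c in prune_block_k_gt_group(configs, {"QUANT_GROUP_SIZE": 64})]
kept_unknown = [c.kwargs["BLOCK_K"] for c in prune_block_k_gt_group(configs, {})]
```

Comparing the lengths of the input and output lists once, in a unit test or a log line, is enough to detect the silent no-op case the review warns about.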

Notes

  • [CHECK] _make_group_offsets returns int32 prefix sums; if tokens_per_expert.cumsum overflows int32 at very large batch * experts the assignment to group_offsets[1:] will wrap silently. The previous code had the same behavior, but worth confirming the upstream contract.
  • [CHECK] The shared max_m is now used as the autotune MAX_M key for both the fc1 (M=tokens) and fc2 launches; that matches prior behavior since both pass the same per-expert token count, but verify no caller passes a different size_per_group to the second matmul.
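
What the shared helper computes, and the int32 range check raised in the first note, can be sketched on the host. Plain Python lists stand in for device tensors; `make_group_offsets` is a hypothetical name and signature, since the real `_make_group_offsets` is not shown in this thread.

```python
from itertools import accumulate

INT32_MAX = 2**31 - 1

def make_group_offsets(tokens_per_expert: list[int]) -> tuple[list[int], int]:
    """Return ([0, prefix sums...], max tokens in any group), checking int32 range."""
    offsets = [0, *accumulate(tokens_per_expert)]
    if offsets[-1] > INT32_MAX:
        raise OverflowError("group offsets do not fit in int32")
    max_m = max(tokens_per_expert, default=0)
    return offsets, max_m

# Computed once at the entry point, then passed to every launch.
group_offsets, max_m = make_group_offsets([3, 0, 5, 2])
```

In this sketch the offsets and `max_m` are produced in one place, which mirrors the single device->host sync described in the PR; an explicit range check replaces the silent wrap-around.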

@songliucsu songliucsu force-pushed the optimize_quant_moe branch from b5946fa to df91212 June 3, 2026 08:05
@github-actions

github-actions Bot commented Jun 3, 2026

Claude Code Review

Verdict: Comment -- Switch from scalar int32 outer-product to fp16 tl.dot with int32 round-trip is plausible but has a few correctness/perf concerns worth addressing.

Summary

Replaces the rank-1 int32 accumulator loop in the grouped int8 MoE GEMM with a tiled tl.dot that casts int8 operands to fp16, rounds the fp32 result back to int32 per BLOCK_K tile, and accumulates per quant group. Also adds a BLOCK_K autotune dimension with a pruner, and hoists group_offsets/max_m computation to a single device->host sync per quant_moe_experts_impl call.

Must fix

  • [BLOCKER] BLOCK_K is not required to divide QUANT_GROUP_SIZE -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:158-166 -- The pruner only enforces BLOCK_K <= QUANT_GROUP_SIZE: BLOCK_K=64 with QUANT_GROUP_SIZE=128 is allowed (fine) and BLOCK_K=128 with QUANT_GROUP_SIZE=64 is filtered, but divisibility is not required. With QUANT_GROUP_SIZE=96 and BLOCK_K=64, the second tile loads k 64..127 and the k_in_group < 96 mask zeros its upper half. That is correct, and the K_TILES_PER_GROUP ceil-div still works, but the masked lanes are wasted MMA work. Verify that all production quant_group_size values are powers of two or multiples of the available BLOCK_Ks, or add a divisibility filter to avoid wasted MMA work and surprises.

Suggestions

  • [MAJOR] fp16 dot exactness claim is hardware-specific -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:13-21 -- The comment asserts ILU's fp16 tl.dot performs a true fp32 multiply-accumulate verified empirically. This is a strong assumption guarding integer correctness; consider adding a runtime/CI numerical test (e.g. random int8 matmul vs reference) gated on the ILU backend so a future driver/compiler change cannot silently break int correctness.
  • [MAJOR] Rounding of an already-exact fp32 value is unnecessary but harmless; the tl.where(... 0.5, -0.5) adds work on the hot path -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:177-178 -- If the dot output is truly fp32-exact integer-valued (as the header claims), .to(tl.int32) already round-trips losslessly. The defensive round costs an extra compare+select+add per element on every BLOCK_K tile in the inner loop. Either drop it (and rely on the exactness invariant) or document it as a belt-and-braces guard.
  • [MAJOR] mask_n used for B but offs_n_up < N recomputed for B_up -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:184-186 -- Minor consistency: precompute mask_n_up = offs_n_up < N once outside the k loop, mirroring mask_n, to avoid recomputing the comparison every tile.
  • [MINOR] Autotune key includes QUANT_GROUP_SIZE but configs aren't actually parameterized by it -- mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py:170-174 -- Adding QUANT_GROUP_SIZE to the key forces re-tuning per group size, which is correct given the pruner depends on it, but it also multiplies cache entries. Confirm that production only uses 1–2 distinct group sizes.

Notes

  • [CHECK] The header comment claims "large-magnitude sums that exceed fp16 max of 65504 also come back finite/exact rather than inf" — this depends on Triton lowering int8->fp16 inputs into an fp32-accumulating MMA on ILU. Worth a one-off saturation test (e.g. all 127s, BLOCK_K=128) in the test suite to catch regressions.
  • [CHECK] _make_group_offsets returns torch.zeros(...) then assigns [1:] = cum; this is one extra kernel vs the prior torch.empty + explicit [0]=0. Negligible, but confirm there's no measurable launch overhead on small num_groups decode paths.

@songliucsu

Collaborator Author

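In production the group sizes are only {128, 256, 320, 512} and K (a multiple of 128); the only case not divisible by BLOCK_K is 320 with BK=128.
That case is correct (tail lanes are masked to 0, with no cross-group reads and no duplication), and the int4 down=320 accuracy test passes.
No divisibility filter is added: autotune selects by measured time, so BK=128 is chosen only when it is genuinely faster than the alternatives, and imposing a constraint could remove the best config. This is stated in the docstring.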
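
The masking claim for the non-divisible case can be checked by enumeration. This sketch assumes tile `t` covers in-group indices `t * BLOCK_K + lane` and applies the `k_in_group < QUANT_GROUP_SIZE` mask quoted in the reviews; `covered_counts` is a hypothetical helper and the BLOCK_K candidates are the {32, 64, 128} listed earlier in the thread.

```python
def covered_counts(quant_group_size: int, block_k: int) -> list[int]:
    """Count how many tiles touch each k index of one quant group after masking."""
    tiles = -(-quant_group_size // block_k)        # ceil division
    counts = [0] * quant_group_size
    for t in range(tiles):
        for lane in range(block_k):
            k_in_group = t * block_k + lane
            if k_in_group < quant_group_size:      # tail lanes are masked out
                counts[k_in_group] += 1
    return counts

# Every in-group index should be covered exactly once for all production sizes.
all_covered_once = all(
    set(covered_counts(g, bk)) == {1}
    for g in (128, 256, 320, 512)
    for bk in (32, 64, 128)
)
tail_lanes_masked = 3 * 128 - 320   # masked lanes in the 320 / BK=128 case
```

For group size 320 with BK=128 this gives three tiles with 64 masked tail lanes, so the cost of not filtering is extra MMA work rather than a correctness risk.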

@songliucsu songliucsu marked this pull request as draft June 3, 2026 08:40
@madengfei madengfei self-assigned this Jun 4, 2026
@madengfei madengfei changed the title [ilu/ttx] optimize quant_moe_export WIP [ilu/ttx] optimize quant_moe_export Jun 4, 2026